# -*- coding: utf-8 -*-
# File: ops.py
# Custom PyTorch operations.

import torch
import torch.nn.functional as F

# Public API of this module: only ``crop_or_pad_as`` is exported by a
# star-import.
__all__ = ['crop_or_pad_as']


def crop_or_pad_as(input: torch.Tensor, other: torch.Tensor, *,
                   pad_val: float = 0.0) -> torch.Tensor:
    """Center-crop and/or pad ``input`` so its last two dims match ``other``.

    The last two dimensions of both tensors are treated as (height, width).
    Each of the two axes is handled independently:

    * If ``input`` is smaller than ``other`` along an axis, it is padded with
      ``pad_val``. The leading side gets ``floor(diff / 2)`` elements and the
      trailing side gets the rest.
    * If ``input`` is larger along an axis, a window of ``other``'s size is
      taken, starting at offset ``floor(diff / 2)``.

    Args:
        input: Tensor of shape ``[..., h_i, w_i]``. The name shadows the
            builtin but is kept for keyword-argument compatibility.
        other: Tensor of shape ``[..., h_o, w_o]``. Only its last two sizes
            are read.
        pad_val: Constant fill value used when padding is required.

    Returns:
        Tensor of shape ``input.shape[:-2] + other.shape[-2:]``.
    """
    src_h, src_w = input.shape[-2:]
    dst_h, dst_w = other.shape[-2:]

    def _margins(have, want):
        # Split a size deficit into (leading, trailing) padding; (0, 0) if none.
        lead = max(0, (want - have) // 2)
        return lead, max(0, want - have - lead)

    top, bottom = _margins(src_h, dst_h)
    left, right = _margins(src_w, dst_w)
    # F.pad lists padding from the last dim backwards: (W_l, W_r, H_t, H_b).
    padding = [left, right, top, bottom]
    if any(amount > 0 for amount in padding):
        input = F.pad(input, padding, value=pad_val)

    # An axis that was padded now has exactly the target size, so its crop
    # offset computed from the pre-padding size clamps to 0.
    row0 = max(0, (src_h - dst_h) // 2)
    col0 = max(0, (src_w - dst_w) // 2)
    return input[..., row0:row0 + dst_h, col0:col0 + dst_w]
